import torch
import numpy as np

from dada.optimizer import DADA, UGM, WDA, DoG, Prodigy

from dada.model.model_runner import ModelRunner
from dada.model.polynomial_feasible.poly_feasible_model import PolynomialFeasibleModel
from dada.utils import *


class PolynomialFeasibleRunner(ModelRunner):
    """Runner that compares several optimizers on ``PolynomialFeasibleModel``.

    For every ``q`` in ``q_list`` the same randomly generated problem instance is
    optimized with WDA, UGM, DoG, Prodigy and DADA, and the collected results are
    plotted together via ``plot_optimizers_result``.
    """

    def __init__(self, params):
        """Read and validate the runner-specific settings, then defer to ``ModelRunner``.

        Args:
            params: Subscriptable mapping that must provide ``'num_polyhedron'``,
                ``'radius'`` and ``'q_list'``. It is also passed unchanged to
                ``ModelRunner.__init__``.

        Raises:
            KeyError: If one of the required keys is missing.
            ValueError: If one of the required keys maps to ``None``.
        """
        self.num_polyhedron = self._require(params, 'num_polyhedron')
        self.radius = self._require(params, 'radius')
        self.q_list = self._require(params, 'q_list')

        super().__init__(params)

    @staticmethod
    def _require(params, key):
        """Return ``params[key]``, raising ``ValueError`` when the value is ``None``."""
        value = params[key]
        if value is None:
            raise ValueError(f'{key} is None')
        return value

    @staticmethod
    def _build_optimizers(q, n, d, a_matrix, b_matrix, optimal_point):
        """Build the ``[optimizer, model]`` pairs compared for a single parameter set.

        Every model starts from the all-ones vector of length ``d``. The returned
        order is WDA, UGM, DoG, Prodigy, DADA.
        """
        def fresh_init():
            # Deliberately a new leaf tensor per model rather than one shared
            # tensor. If PolynomialFeasibleModel keeps `init_point` as its
            # trainable parameter without copying it (not visible here — TODO
            # confirm), the in-place steps of one optimizer would otherwise
            # move the starting point of every other method.
            return torch.ones(d, requires_grad=True, dtype=torch.double)

        # Initial distance to the optimum; WDA takes it as its `d0` argument.
        d0 = np.linalg.norm(optimal_point - fresh_init().detach().numpy())

        # Dual Averaging Method
        da_model = PolynomialFeasibleModel(d, q, n, a_matrix, b_matrix, init_point=fresh_init())
        da_optimizer = WDA(da_model.params(), d0=d0)

        # GD With Line Search Method
        gd_line_search_model = PolynomialFeasibleModel(d, q, n, a_matrix, b_matrix, init_point=fresh_init())
        gd_line_search_optimizer = UGM(gd_line_search_model.params())

        # DoG Method
        dog_model = PolynomialFeasibleModel(d, q, n, a_matrix, b_matrix, init_point=fresh_init())
        dog_optimizer = DoG(dog_model.params())

        # Prodigy Method
        prodigy_model = PolynomialFeasibleModel(d, q, n, a_matrix, b_matrix, init_point=fresh_init())
        prodigy_optimizer = Prodigy(prodigy_model.params())

        # DADA Method
        dada_model = PolynomialFeasibleModel(d, q, n, a_matrix, b_matrix, init_point=fresh_init())
        dada_optimizer = DADA(dada_model.params())

        return [
            [da_optimizer, da_model],
            [gd_line_search_optimizer, gd_line_search_model],
            [dog_optimizer, dog_model],
            [prodigy_optimizer, prodigy_model],
            [dada_optimizer, dada_model],
        ]

    def run(self, iterations, model_name, save_plot, plots_directory):
        """Run every optimizer for each ``q`` in ``self.q_list`` and plot the results.

        One problem instance is generated up front and shared by all ``q`` values.
        ``optimal_point`` is a random point on the sphere of radius ``self.radius``,
        and ``a_matrix`` / ``b_matrix`` come from
        ``PolynomialFeasibleModel.generate_function_variables``.

        Args:
            iterations: Iteration count passed to ``run_different_opts``; also used
                to derive the plot marker stride.
            model_name: Forwarded to ``plot_optimizers_result``.
            save_plot: Forwarded as ``save`` to ``plot_optimizers_result``.
            plots_directory: Forwarded to ``plot_optimizers_result``.
        """
        # Normalising a standard Gaussian sample yields a random direction;
        # scaling it places the optimum at distance `radius` from the origin.
        random_normal = np.random.multivariate_normal(mean=[0] * self.vector_size, cov=np.identity(self.vector_size))
        optimal_point = self.radius * (random_normal / np.linalg.norm(random_normal))
        a_matrix, b_matrix = PolynomialFeasibleModel.generate_function_variables(
            self.vector_size,
            self.num_polyhedron,
            optimal_point)

        params = [Param(names=["q", "n", "d"], values=[q, self.num_polyhedron, self.vector_size]) for q in self.q_list]
        value_distances_per_param = {}
        d_estimation_error_per_param = {}

        # Remains empty if q_list is empty; otherwise it ends up holding the
        # optimizers built for the last param, which are passed to the plotter.
        optimizers = []

        for param in params:
            print(param)
            q = param.get_param("q")
            n = param.get_param("n")
            d = param.get_param("d")

            optimizers = self._build_optimizers(q, n, d, a_matrix, b_matrix, optimal_point)

            d_estimation_error, value_distances = run_different_opts(optimizers, iterations, optimal_point, log_per=500)
            value_distances_per_param[param] = value_distances
            d_estimation_error_per_param[param] = d_estimation_error

        # Guard against a zero stride when iterations < 10. The value is
        # presumably forwarded to matplotlib's `markevery`, where an integer 0
        # is invalid — confirm in plot_optimizers_result.
        plot_optimizers_result(optimizers, params, value_distances_per_param, d_estimation_error_per_param,
                               model_name=model_name, save=save_plot, plots_directory=plots_directory,
                               mark_every=max(1, iterations // 10))
